import json
import pickle


def generate_filename(base_name, threshold):
    """Return the JSON filename for ``base_name``.

    A strictly positive ``threshold`` is embedded as a ``_<threshold>``
    suffix in front of the ``.json`` extension; zero or a negative
    threshold produces the plain ``<base_name>.json``.
    """
    suffix = f'_{threshold}' if threshold > 0 else ''
    return f'{base_name}{suffix}.json'

# --- Load inputs -----------------------------------------------------------
# All JSON inputs are opened with an explicit UTF-8 encoding so decoding does
# not depend on the platform's locale default.

# Out-of-distribution test samples; each sample carries an 'apis' collection.
with open('./ood_test_data.json', 'r', encoding='utf-8') as fp:
    ood_test_data = json.load(fp)

# Distinct API names across all OOD samples (only used for the count below).
ood_apis = {api for sample in ood_test_data for api in sample['apis']}
print('len of ood apis:', len(ood_apis))

# Synthetic training samples; filtered later by their 'pruned_code_score'.
with open('./synthetic_train_data.json', 'r', encoding='utf-8') as fp:
    synthetic_train_data = json.load(fp)

# Minimum 'pruned_code_score' a synthetic training sample needs to be kept.
# generate_filename() only adds a suffix for positive thresholds, so -1
# yields un-suffixed output filenames.
threshold = -1

# Samples identified by a 'key', with 'line_by_line', 'description' and
# 'query' fields (presumably the original/seed dataset -- confirm).
with open('./query_and_description_2.json', 'r', encoding='utf-8') as fp:
    query_and_description_2 = json.load(fp)

# Dataset split; the 'test' and 'dev' entries list sample keys.
with open('./dataset_split_keys.json', 'r', encoding='utf-8') as fp:
    dataset_split = json.load(fp)

# Statistics indexed by sample key; each entry exposes 'action_names'.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so
# statistics.pkl must come from a trusted source.
with open('../../data/statistics.pkl', 'rb') as fp:
    stat = pickle.load(fp)

# Sets give O(1) membership checks when routing samples by split.
test_keys = set(dataset_split['test'])
dev_keys = set(dataset_split['dev'])


# Output accumulators: `synthesized_data` and `ALL_data` both start with the
# synthetic training samples that pass the score threshold.
synthesized_data = []
ALL_data = []

for sample in synthetic_train_data:
    # Skip samples whose pruned-code score does not reach the threshold.
    if not (sample['pruned_code_score'] >= threshold):
        continue

    record = {
        'line_by_line': sample['pruned_code'],
        'description': sample['pruned_thought'],
        'query': sample['query'],
        'apis': sample['pruned_apis'],
        'key': 'synthesized_training_data',
    }
    synthesized_data.append(record)
    # Each list gets its own dict object, as separate literals would.
    ALL_data.append(dict(record))

print('filtered_data length', len(synthesized_data))


def _seed_record(sample):
    """Build a fresh output record for one `query_and_description_2` sample.

    The 'apis' field is a new list made from the sample's 'action_names'
    entry in `stat`.
    """
    return {
        'line_by_line': sample['line_by_line'],
        'description': sample['description'],
        'query': sample['query'],
        'apis': list(stat[sample['key']]['action_names']),
        'key': sample['key'],
    }


# Keys belonging to either the test or the dev split.
held_out_keys = test_keys | dev_keys

for sample in query_and_description_2:
    # Only test/dev samples go into `synthesized_data` ...
    if sample['key'] in held_out_keys:
        synthesized_data.append(_seed_record(sample))

    # ... while every sample goes into `ALL_data`.
    ALL_data.append(_seed_record(sample))

# Every OOD test sample is added to all three collections, tagged with the
# 'synthesized_ood_test_data' key. Each collection receives its own dict.
for sample in ood_test_data:
    for target in (synthesized_data, query_and_description_2, ALL_data):
        target.append(
            {
                'line_by_line': sample['pruned_code'],
                'description': sample['pruned_thought'],
                'query': sample['query'],
                'apis': sample['pruned_apis'],
                'key': 'synthesized_ood_test_data',
            }
        )

# Write each collection to its own JSON file (name derived from the base name
# and `threshold`) and report how many records it holds.
for base_name, records in (
    ('full_synthesized_data', synthesized_data),
    ('full_synthesized_with_seed_data', ALL_data),
    ('query_and_description_3', query_and_description_2),
):
    out_name = generate_filename(base_name, threshold)
    with open(out_name, 'w') as fp:
        json.dump(records, fp, indent=4)

    print(f'len of {out_name}', len(records))